import os
import random
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Fail early, with a hint about where to put the files, if the helper
# modules expected next to this script are missing.
path_lenet = os.path.abspath(os.path.join(BASE_DIR, 'model', 'lenet.py'))
path_tools = os.path.abspath(os.path.join(BASE_DIR, 'tools', 'common_tools.py'))
assert os.path.exists(path_lenet), "{}不存在，请将lenet.py文件放到{}".format(path_lenet, os.path.dirname(path_lenet))
assert os.path.exists(path_tools), "{}不存在，请将common_tools.py文件放到{}".format(path_tools, os.path.dirname(path_tools))

import sys
# Parent directory of this script, appended to sys.path so that the
# `deepeye` imports below can be resolved from there.
# Derived from the already-absolute BASE_DIR rather than from
# `os.path.dirname(__file__) + os.path.sep + '..'`: when __file__ is a bare
# relative name (e.g. `python script.py` on Python < 3.9) dirname() is '' and
# the concatenation yields '/..', i.e. the filesystem root.
hello_pytorch_DIR = os.path.abspath(os.path.join(BASE_DIR, os.path.pardir))
sys.path.append(hello_pytorch_DIR)

# These imports must stay below the sys.path tweak above.
from deepeye.tools.my_dataset import RMBDataset
from deepeye.tools.common_tools import set_seed, transform_invert

set_seed(1)  # set the random seed

# Script-level settings
MAX_EPOCH = 10  # number of passes made over train_loader by the loop at the bottom of the script
BATCH_SIZE = 1  # batch size handed to both DataLoaders
LR = 0.01  # not referenced in the visible part of this script; presumably a learning rate — TODO confirm
log_interval = 10  # not referenced in the visible part of this script; presumably a logging period — TODO confirm
val_interval = 1  # not referenced in the visible part of this script; presumably a validation period — TODO confirm
rmb_label = {'1': 0, '100': 1}  # presumably maps banknote class names to label indices — verify against RMBDataset

class AddPepperNoise(object):
    """Randomly corrupt an image with salt-and-pepper noise.

    Args:
        snr (float): Fraction of pixels left untouched, expected in [0, 1].
            The remaining ``1 - snr`` is split evenly between salt (255)
            and pepper (0) pixels.
        p (float): Probability that the noise is applied on a given call.

    Raises:
        TypeError: If ``snr`` or ``p`` is not an int or a float.
    """

    def __init__(self, snr, p=0.9):
        # Explicit check instead of ``assert`` so it is not stripped by
        # ``python -O``; ints such as ``p=1`` are accepted as well as floats.
        if not (isinstance(snr, (int, float)) and isinstance(p, (int, float))):
            raise TypeError(
                "snr and p must be numbers, got {} and {}".format(
                    type(snr).__name__, type(p).__name__))
        self.snr = snr
        self.p = p

    @staticmethod
    def _salt_and_pepper(pixels, snr):
        """Return a noisy uint8 copy of an (H, W) or (H, W, C) pixel array.

        One draw is made per pixel position, so every channel of a pixel
        receives the same salt or pepper value. ``pixels`` is not modified.
        """
        noisy = np.array(pixels)  # np.array copies, the input stays intact
        h, w = noisy.shape[:2]
        noise_pct = 1 - snr
        # 0 = keep the pixel, 1 = salt, 2 = pepper
        mask = np.random.choice(
            (0, 1, 2), size=(h, w), p=[snr, noise_pct / 2., noise_pct / 2.])
        # A 2-D boolean mask selects whole pixels, which works for both
        # single-channel (2-D) and multi-channel (3-D) arrays.
        noisy[mask == 1] = 255  # salt
        noisy[mask == 2] = 0    # pepper
        return noisy.astype('uint8')

    def __call__(self, img):
        """Apply the noise with probability ``p``.

        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: A new RGB image with noise when the transform fires,
            otherwise ``img`` itself, unchanged.
        """
        if random.uniform(0, 1) >= self.p:
            return img
        noisy = self._salt_and_pepper(img, self.snr)
        return Image.fromarray(noisy).convert('RGB')

# ============================ step 1/5: data ============================
# The train/valid split is expected to exist already under data/rmb_split.
split_dir = os.path.abspath(os.path.join(BASE_DIR, 'data', 'rmb_split'))
if not os.path.exists(split_dir):
    raise Exception('\n{} 不存在，回到06-dataloader_splitDataset.py生成数据。'.format(split_dir))
train_dir, valid_dir = (os.path.join(split_dir, subset) for subset in ('train', 'valid'))

# Per-channel statistics passed to Normalize in both pipelines.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

target_size = (224, 224)

# Training pipeline: resize, custom noise transform (both operate on an
# Image), then tensor conversion and normalisation.
train_steps = [
    transforms.Resize(target_size),
    AddPepperNoise(0.9, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
train_transform = transforms.Compose(train_steps)

# Validation pipeline: same as training but without the noise step.
valid_steps = [
    transforms.Resize(target_size),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
valid_transform = transforms.Compose(valid_steps)


# Dataset instances, one per split
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# DataLoaders; only the training loader shuffles
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE)

# ============================ step 5/5: training loop ============================
# No optimisation happens here: each batch is only displayed so the effect of
# train_transform can be inspected.
for epoch in range(MAX_EPOCH):
    for i, (inputs, label) in enumerate(train_loader):  # inputs: B C H W
        # First sample of the batch: C H W
        img_tensor = inputs[0]
        # transform_invert undoes train_transform so that the data fed to
        # the model can be viewed as an image.
        img = transform_invert(img_tensor, train_transform)

        plt.imshow(img)
        plt.show()
        plt.pause(0.5)
        plt.close()

